{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Exercise: training on MNIST with Keras"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.datasets import mnist\n",
    "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Inspect the training set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(60000, 28, 28)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_images.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "60000"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Inspect the test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10000, 28, 28)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_images.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10000"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([7, 2, 1, ..., 4, 5, 6], dtype=uint8)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Try training a network"
   ]
  },
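  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Add layers: the network consists of two Dense (fully connected) layers. The last one is a 10-way softmax layer that returns 10 probability scores, one per digit class, summing to 1."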
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:63: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.\n",
      "\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:488: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n",
      "\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:3626: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from keras import models\n",
    "from keras import layers\n",
    "\n",
    "network = models.Sequential()\n",
    "network.add(layers.Dense(512, activation='relu', input_shape=(28 * 28,)))\n",
    "network.add(layers.Dense(10, activation='softmax'))"
   ]
  },
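  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the claim that the softmax outputs sum to 1 concrete, here is a minimal NumPy sketch of the softmax function. The `softmax` helper and the raw scores are made up for illustration; Keras applies its own implementation inside the final `Dense` layer.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def softmax(scores):\n",
    "    # Hypothetical helper for illustration only.\n",
    "    # Subtracting the max keeps exp() numerically stable without changing the result.\n",
    "    shifted = scores - np.max(scores)\n",
    "    exps = np.exp(shifted)\n",
    "    return exps / np.sum(exps)\n",
    "\n",
    "# Made-up raw scores for the 10 digit classes.\n",
    "scores = np.array([1.0, 2.0, 0.5, 0.1, 3.0, 0.0, 1.5, 0.2, 0.3, 0.4])\n",
    "probs = softmax(scores)\n",
    "\n",
    "print(probs.sum())     # approximately 1.0\n",
    "print(probs.argmax())  # 4, the class with the largest raw score\n",
    "```\n",
    "\n",
    "The predicted digit is simply the index of the largest probability."
   ]
  },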
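  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Set the compilation parameters\n",
    "* Loss function: measures how far the network's predictions are from the true labels, so that training can steer the weights in the right direction.\n",
    "* Optimizer: the mechanism that updates the network's weights based on the training data and the loss function.\n",
    "* Metrics to monitor during training and testing: here only accuracy, the fraction of images classified correctly."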
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.7/site-packages/keras/optimizers.py:711: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
      "\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:2857: calling reduce_sum_v1 (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "keep_dims is deprecated, use keepdims instead\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:2861: The name tf.log is deprecated. Please use tf.math.log instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "network.compile(optimizer='sgd',\n",
    "               loss='categorical_crossentropy',\n",
    "               metrics=['accuracy'])"
   ]
  },
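  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of what `categorical_crossentropy` measures, the sketch below computes the textbook formula (the mean over samples of `-sum(y_true * log(y_pred))`) in NumPy. The helper name, the `eps` clipping value, and the example predictions are assumptions made for this sketch; the Keras implementation may differ in such details.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def categorical_crossentropy(y_true, y_pred, eps=1e-7):\n",
    "    # Hypothetical helper; eps is an assumed clipping value that avoids log(0).\n",
    "    y_pred = np.clip(y_pred, eps, 1.0)\n",
    "    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))\n",
    "\n",
    "# One sample whose true class is index 1 (one-hot encoded).\n",
    "y_true = np.array([[0.0, 1.0, 0.0]])\n",
    "confident = np.array([[0.05, 0.90, 0.05]])  # made-up prediction\n",
    "unsure = np.array([[0.40, 0.30, 0.30]])     # made-up prediction\n",
    "\n",
    "loss_confident = categorical_crossentropy(y_true, confident)\n",
    "loss_unsure = categorical_crossentropy(y_true, unsure)\n",
    "print(loss_confident, loss_unsure)\n",
    "```\n",
    "\n",
    "The loss is smaller when more probability is assigned to the true class, which is the signal the optimizer uses to update the weights."
   ]
  },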
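  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Image preprocessing\n",
    "Reshape each 28 × 28 pixel matrix into a flat vector of 784 values,\n",
    "and scale each pixel from a uint8 in [0, 255] to a float32 in [0, 1]."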
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_images = train_images.reshape((60000, 28*28))\n",
    "train_images = train_images.astype('float32')/255\n",
    "\n",
    "test_images = test_images.reshape((10000, 28*28))\n",
    "test_images = test_images.astype('float32')/255"
   ]
  },
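  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The effect of these two steps can be checked on a tiny stand-in array. The `fake_images` data below is made up for illustration and is not MNIST; the sketch also uses `-1` in `reshape` so that the sample count does not have to be hard-coded.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Made-up stand-in for MNIST: 2 images of 28 x 28 uint8 pixels.\n",
    "fake_images = np.zeros((2, 28, 28), dtype='uint8')\n",
    "fake_images[0, 0, 0] = 255  # one fully bright pixel\n",
    "\n",
    "# Flatten each image to a vector; -1 lets NumPy infer 28 * 28 = 784.\n",
    "flat = fake_images.reshape((len(fake_images), -1))\n",
    "\n",
    "# Convert to float32 and scale from [0, 255] to [0, 1].\n",
    "scaled = flat.astype('float32') / 255\n",
    "\n",
    "print(flat.shape)    # (2, 784)\n",
    "print(scaled.dtype)  # float32\n",
    "print(scaled.min(), scaled.max())\n",
    "```"
   ]
  },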
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Prepare the labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.utils import to_categorical\n",
    "\n",
    "train_labels = to_categorical(train_labels)\n",
    "test_labels = to_categorical(test_labels)"
   ]
  },
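  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`to_categorical` turns integer class labels into one-hot vectors, which is the label format `categorical_crossentropy` expects. The NumPy sketch below mimics that basic behavior; the `one_hot` helper is hypothetical and is not the Keras implementation.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def one_hot(labels, num_classes):\n",
    "    # Hypothetical helper: one row per label, with a 1.0 in the label's column.\n",
    "    encoded = np.zeros((len(labels), num_classes), dtype='float32')\n",
    "    encoded[np.arange(len(labels)), labels] = 1.0\n",
    "    return encoded\n",
    "\n",
    "# The first three training labels shown in section 2.\n",
    "labels = np.array([5, 0, 4])\n",
    "encoded = one_hot(labels, num_classes=10)\n",
    "\n",
    "print(encoded.shape)  # (3, 10)\n",
    "print(encoded[0])     # 1.0 at index 5, 0.0 elsewhere\n",
    "```"
   ]
  },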
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. Fit the model\n",
    "`epochs`: the number of passes over the training set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/5\n",
      "60000/60000 [==============================] - 1s 15us/step - loss: 0.2392 - acc: 0.9340\n",
      "Epoch 2/5\n",
      "60000/60000 [==============================] - 1s 14us/step - loss: 0.2360 - acc: 0.9347\n",
      "Epoch 3/5\n",
      "60000/60000 [==============================] - 1s 14us/step - loss: 0.2330 - acc: 0.9354\n",
      "Epoch 4/5\n",
      "60000/60000 [==============================] - 1s 14us/step - loss: 0.2301 - acc: 0.9363\n",
      "Epoch 5/5\n",
      "60000/60000 [==============================] - 1s 14us/step - loss: 0.2272 - acc: 0.9370\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x1523787d0>"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "network.fit(train_images, train_labels, epochs=5, batch_size=256)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Evaluate the fitted model on the test set to measure its accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10000/10000 [==============================] - 0s 31us/step\n",
      "test_acc: 0.9372\n"
     ]
    }
   ],
   "source": [
    "test_loss, test_acc = network.evaluate(test_images, test_labels)\n",
    "print('test_acc:', test_acc)"
   ]
  },
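  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The accuracy metric is the fraction of samples whose most probable predicted class matches the true class. A minimal sketch with made-up softmax outputs and one-hot labels (4 samples, 3 classes; none of these values come from the model above):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Made-up softmax outputs, one row per sample.\n",
    "probs = np.array([[0.7, 0.2, 0.1],\n",
    "                  [0.1, 0.8, 0.1],\n",
    "                  [0.3, 0.3, 0.4],\n",
    "                  [0.6, 0.3, 0.1]])\n",
    "\n",
    "# Made-up one-hot labels for the same samples.\n",
    "one_hot_labels = np.array([[1, 0, 0],\n",
    "                           [0, 1, 0],\n",
    "                           [0, 0, 1],\n",
    "                           [0, 1, 0]])\n",
    "\n",
    "predicted = probs.argmax(axis=1)         # most probable class per sample\n",
    "actual = one_hot_labels.argmax(axis=1)   # true class per sample\n",
    "accuracy = float(np.mean(predicted == actual))\n",
    "\n",
    "print(accuracy)  # 0.75: three of the four samples are classified correctly\n",
    "```"
   ]
  },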
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
